A deep learning system for heart failure mortality prediction

Heart failure (HF) is the final stage in the development of various heart diseases. The mortality rates of HF patients are highly variable, ranging from 5% to 75%. Evaluating the all-cause mortality of HF patients is an important means of preventing death and improving patients' health. In practice, however, machine learning models struggle to achieve good results on HF data, which are high-dimensional, imbalanced and contain missing values. We therefore propose a deep learning system. In this system, we propose an indicator vector that marks whether each value is true or padded, which quickly solves the missing-value problem and helps expand the data dimensions. We then use a convolutional neural network with different kernel sizes to extract feature information, and apply a multi-head self-attention mechanism to obtain information across all channels, which is essential for improving the system's performance. In addition, the focal loss function is introduced to better handle the imbalance problem. The experimental data come from the public MIMIC-III database and contain valid data for 10311 patients. The proposed system predicts four death types effectively and quickly: death within 30 days, within 180 days, within 365 days and after 365 days. Our study uses Deep SHAP to interpret the deep learning model and obtains the top 15 characteristics. These characteristics further confirm the effectiveness and rationality of the system and can help provide better medical services.


Introduction
Heart failure (HF) is a condition in which structural or functional abnormalities of the heart, arising from a variety of causes, impair ventricular systolic or diastolic function [1]. It is the final stage in the development of various heart diseases [2]. According to the American College of Cardiology, cardiovascular disease causes one-third of deaths worldwide. More than five million people in the United States suffer from heart failure, and 550,000 new cases are diagnosed each year [3][4][5]. Meanwhile, in China the prevalence of HF for people over

Data extraction
The HF dataset is extracted from MIMIC-III v1.4. MIMIC-III (Medical Information Mart for Intensive Care III) [21,22] is a large, freely available database comprising health-related data associated with over forty thousand patients who stayed in critical care units of the Beth Israel Deaconess Medical Center between 2001 and 2012. MIMIC-III v1.4 uses ICD-9 codes, according to which we extracted 25 types of heart failure, shown in Table 1.
Restricting the cohort to patients aged 18 years or older, we extracted 10311 patients in total. The starting point is the time of a patient's first hospitalization for HF, and the endpoint is the time of death or discharge. Each patient therefore has a death time, for example 0 or 364: 0 means the patient is alive, and 364 means the patient died 364 days after the first admission. Based on death time, we divide HF patients into five categories: survivors who are alive during the statistical period, and patients who died within 30 days, within 180 days, within 365 days, and after 365 days. Each group of deceased patients is paired with the survivors, which gives four binary classification tasks. Labels take values in {0, 1}: deceased patients are labeled as positive samples (1) and survivors as negative samples (0), as shown in Table 2.
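A minimal sketch of the labelling step in plain Python. The text does not say whether day 30, 180 or 365 belongs to the shorter or the longer window, so inclusive upper bounds are an assumption, as is reading the imbalance rate of formula (1) as negatives divided by positives; all function names are hypothetical.

```python
from collections import Counter

# Assumed window bounds in days (inclusive upper bounds are our assumption).
WINDOWS = [("W30D", 30), ("W180D", 180), ("W365D", 365)]
DATASETS = ("W30D", "W180D", "W365D", "A365D")


def death_group(death_day: int) -> str:
    """Map a death time (days after first HF admission, 0 = alive) to one of five groups."""
    if death_day == 0:
        return "alive"
    for name, upper in WINDOWS:
        if death_day <= upper:
            return name
    return "A365D"  # died more than 365 days after the first admission


def binary_datasets(death_days):
    """Pair each death group (label 1) with the survivors (label 0) as (index, label) lists."""
    groups = [death_group(d) for d in death_days]
    return {
        name: [(i, int(g == name)) for i, g in enumerate(groups) if g in (name, "alive")]
        for name in DATASETS
    }


def imbalance_rate(labels):
    """Assumed form of formula (1): number of negatives divided by number of positives."""
    counts = Counter(labels)
    return counts[0] / counts[1]
```

Under this reading of formula (1), the reported rates are mutually consistent: 1677 A365D patients at a rate of 2.5277 imply about 4239 survivors, which in turn imply about 2472, 1393 and 530 patients in W30D, W180D and W365D; together these sum to the 10311 extracted patients.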
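In addition, we calculate the imbalance rate (IR) defined in formula (1), where N_neg and N_pos are the numbers of negative and positive samples in a dataset:
$$\mathrm{IR} = \frac{N_{neg}}{N_{pos}} \tag{1}$$
The imbalance rates of the four datasets, W30D, W180D, W365D and A365D, are 1.7148, 3.0431, 7.9981 and 2.5277, respectively. The W365D dataset is the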

PLOS ONE
discrete features by one-hot coding. These are gender (index 1 in Table 3), medication (indices 3-9), surgery (indices 10-13), related diseases (indices 14-39) and Stayed in CCU (index 42). As shown in Table 3, gender is coded 0 for male and 1 for female. According to their efficacy, the medications are divided into 7 groups: ACEI, ARB, beta-receptor blockers, CCB, digitalis, diuretics and nitrates. The surgery feature contains 4 classes: left ventricular assist device (LVAD), cardiac resynchronization therapy (CRT), automatic implantable cardioverter/defibrillator (ICD) check and heart transplantation. In addition, we summarize 26 diseases as related diseases (see Appendix A in S1 File), such as cardiac arrhythmias and cardiomyopathy. Medication, surgery, related diseases and staying in the CCU are all represented in the same way: 1 means the patient has used a medication from that group, has undergone that surgery, and so on, and 0 means not.
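A sketch of this binary encoding for a single patient record. The drug-to-group lookup is a hypothetical, deliberately tiny subset chosen for illustration, the record field names are assumptions, and the 26 related-disease flags (Appendix A in S1 File) are omitted for brevity.

```python
MEDICATION_GROUPS = ["ACEI", "ARB", "beta-blocker", "CCB", "digitalis", "diuretic", "nitrate"]
SURGERIES = ["LVAD", "CRT", "ICD", "heart transplantation"]

# Hypothetical lookup from drug name to medication group (illustrative subset only).
DRUG_TO_GROUP = {
    "lisinopril": "ACEI",
    "losartan": "ARB",
    "metoprolol": "beta-blocker",
    "amlodipine": "CCB",
    "digoxin": "digitalis",
    "furosemide": "diuretic",
    "nitroglycerin": "nitrate",
}


def encode_binary(patient):
    """Gender, 7 medication-group flags, 4 surgery flags and the CCU flag as 0/1 values."""
    gender = [1 if patient["gender"] == "F" else 0]  # 0 = male, 1 = female
    used = {DRUG_TO_GROUP[d] for d in patient["drugs"] if d in DRUG_TO_GROUP}
    meds = [1 if group in used else 0 for group in MEDICATION_GROUPS]
    surgery = [1 if s in patient["surgeries"] else 0 for s in SURGERIES]
    ccu = [1 if patient["stayed_in_ccu"] else 0]
    return gender + meds + surgery + ccu
```

Several drugs from the same group collapse to a single flag, which matches the "used one of the medications in that group" reading of the text.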
Proposed indicator vector for missing values method. Since 24 dimensions of the laboratory-test (see Appendix B in S1 File), heart rate and BMI features contain missing values, these values must be filled. Imputation with the mean and variance is the most widely used missing-value imputation technique [23]. However, many characteristics in the MIMIC-III HF database have very large variance. For example, the white blood cell count (WBC) has a mean of 17.1002 K/uL and a variance of 15.1180, so the mean/variance method is not suitable. Instead, we choose a simple approach: each missing value is filled with a value from the normal range of that characteristic, and the filled value is marked as padded. An indicator of 1 means the value was missing and has been filled, and 0 means it is a true (observed) value. This approach quickly resolves the incompleteness of the database.
A sample processed by this filling method is shown in Table 4. The hemoglobin value is missing, so we fill it with a random value within the normal range and set its indicator to 1, marking it as filled rather than observed. The heart rate in Table 4 is a true value, so its indicator is 0. The other features are handled in the same way as hemoglobin and heart rate.
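A minimal sketch of the indicator-vector imputation. The normal ranges below are illustrative placeholders, not the ranges used in the study, and placing each flag directly after its feature is an assumption about the layout; missing values are drawn uniformly from the range, following the "random value in the normal range" described for Table 4.

```python
import random

# Illustrative normal ranges only (hypothetical; the study's ranges are not listed here).
NORMAL_RANGES = {
    "hemoglobin": (12.0, 17.5),   # g/dL
    "heart_rate": (60.0, 100.0),  # beats per minute
}


def impute_with_indicator(record, ranges=NORMAL_RANGES, rng=None):
    """Return [value, flag, value, flag, ...] for the features in `ranges`.

    A missing value (None) is replaced by a uniform draw from the feature's
    normal range and flagged 1; an observed value is kept and flagged 0.
    """
    rng = rng or random.Random(0)  # fixed seed so the sketch is reproducible
    out = []
    for name, (low, high) in ranges.items():
        value = record.get(name)
        if value is None:
            out.extend([rng.uniform(low, high), 1])
        else:
            out.extend([value, 0])
    return out
```

For the Table 4 situation, `impute_with_indicator({"hemoglobin": None, "heart_rate": 82.0})` returns a filled hemoglobin value with flag 1, followed by `82.0` with flag 0.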
Since an indicator value is appended for every feature with missing values, the final feature dimension is 90, as shown in Table 3. To further illustrate the features, Table 5 displays their composition. After imputation, we apply Z-score normalization to reduce the influence of outliers and extreme values. [24], CNN has local connectivity, which makes use of local information, optimizes network parameters and structure, reduces training time and improves performance [25]. We therefore use a CNN to predict mortality. Because each sample in our dataset is a vector of HF patient characteristics, the input is one-dimensional, so we adopt 1D convolution. However, 1D convolution slides in only one direction along the feature sequence, so a kernel of a single size can easily separate an indicator from the feature it marks and fail to capture their relationship. We therefore use kernels of multiple sizes to extract key information, which better captures the indicator vector and local correlations.
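The Z-score step mentioned above can be sketched as follows, assuming (standard practice, though not stated in the text) that the mean and standard deviation are estimated on the training split and then reused for validation and test samples. Centering without scaling for constant columns is our own guard against division by zero.

```python
from statistics import mean, pstdev


def fit_zscore(columns):
    """Per-feature (mean, standard deviation) computed from training columns."""
    return [(mean(col), pstdev(col)) for col in columns]


def apply_zscore(row, stats):
    """Standardize one sample: z = (x - mean) / std; constant features are only centered."""
    return [(x - mu) / sd if sd > 0 else x - mu for x, (mu, sd) in zip(row, stats)]
```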
In summary, the deep learning model is shown in Fig 2. Firstly, the input size is 256 × (1 × 90), where 256 is the batch size and 90 is the feature dimension, for the reason explained in the section Proposed indicator vector for missing values method. Secondly, the convolution layer consists of eight (1 × 3) convolutional kernels combined with Batch Normalization (BN) and the ReLU function. Thirdly, the different kernels layer consists of three convolution groups, 3 × (1 × 3), 3 × (1 × 4) and 3 × (1 × 5), each combined with BN, ReLU and MaxPooling with a sampling window of 2 units; each group receives the information learned by the convolution layer. Fourthly, the stitching layer concatenates the outputs of the different kernel groups into a 256 × 396 tensor. The multi-head self-attention structure then attends to the global and local information from the stitching layer; we use two heads in the multi-head self-attention, for the reason discussed in subsection 4.2. Fifthly, fully connected layers of size (396, 128) and (128, 2) process the information from the attention mechanism. Finally, the output is 256 × (1 × 2), which means we obtain 256 mortality predictions at one time.
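The sizes quoted above can be traced with the standard output-length formula for a 1D convolution. The padding values below are assumptions (the text gives kernel sizes but not padding); they were chosen because they reproduce the sizes reported in the paper: 90 input features give a stitched 3 × 132 = 396, the 66-feature ablation without indicators gives 3 × 96 = 288, and the intermediate lengths 90 and 44 match the alternative attention sizes (2, 90) and (2, 44) if the second number is the embedding width.

```python
def conv1d_len(length, kernel, padding=0, stride=1):
    """Output length of a 1D convolution (dilation 1)."""
    return (length + 2 * padding - kernel) // stride + 1


def maxpool_len(length, window=2):
    """Output length of max pooling with stride equal to the window, floor rounding."""
    return (length - window) // window + 1


# Assumed (kernel, padding) settings; padding is not reported in the paper.
FIRST_CONV = (3, 1)                  # the eight (1 x 3) kernels
BRANCHES = [(3, 0), (4, 1), (5, 1)]  # the 3 x (1 x 3), 3 x (1 x 4), 3 x (1 x 5) groups
BRANCH_CHANNELS = 3                  # three kernels per branch


def stitched_shape(n_features):
    """Trace one sample's length through the layers and return the intermediate sizes."""
    after_conv = conv1d_len(n_features, *FIRST_CONV)
    branch_lens = [maxpool_len(conv1d_len(after_conv, k, p)) for k, p in BRANCHES]
    tokens, width = BRANCH_CHANNELS, sum(branch_lens)  # branches joined along length
    return after_conv, branch_lens, (tokens, width), tokens * width
```

`stitched_shape(90)` returns `(90, [44, 44, 44], (3, 132), 396)`, and `stitched_shape(66)` returns `(66, [32, 32, 32], (3, 96), 288)`.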
Multi-head self-attention mechanism. Attention mimics the way the human brain processes information: it identifies the target area by focusing on crucial information and neglecting the rest, which can greatly improve a model's efficiency and accuracy [26]. Yingying Zhang et al. [27] proposed that representations in different subspaces are likely to focus on different information, and that all subspaces together can enhance the global information. This idea inspires us. Our model has receptive fields of three different sizes, and each group focuses on different information. After splicing, all of this information is gathered together. The information within each group forms a relatively complete whole, but there is no correlation or interaction among the three kernel groups. Hence, we use the multi-head self-attention mechanism illustrated in Fig 3 to obtain the global information and pay more attention to the important features.
As shown in Fig 3, the input x (3 × 132) from the stitching layer passes through three different linear layers to generate the keys (denoted K), queries (denoted Q) and values (denoted V), as described in formula (2). It is self-attention because K, Q and V are all computed from the same input x.
$$K = xW_k + b_k, \quad Q = xW_q + b_q, \quad V = xW_v + b_v \tag{2}$$
Here, (W_k, b_k) is the set of parameters of a linear layer, also called a fully connected layer: W_k is the weight and b_k is the bias. (W_q, b_q) and (W_v, b_v) are defined in the same way as (W_k, b_k).
The attention then calculates the similarity between Q and K. The similarity reflects the importance of each extracted value in V, that is, its weight. The weights are scaled according to the input dimension (denoted d_model) and normalized with the softmax function, and the attention value is the resulting weighted average of V. Because Q, K and V are all derived from the same input, this is self-attention. Formula (3) expresses this process:
$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^{T}}{\sqrt{d_{model}}}\right)V \tag{3}$$
The multi-head self-attention mechanism then uses h different heads to obtain different representations of (Q, K, V), and finally concatenates the head outputs through a linear layer:
$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)W^{O}, \quad \mathrm{head}_i = \mathrm{Attention}(QW_i^{Q}, KW_i^{K}, VW_i^{V})$$
where head_i is the i-th head and W^O is the weight of the output linear layer. In our research, h is equal to 2, and after all these steps the output size is 3 × 132.
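To make the linear projections, formula (3) and the multi-head step concrete, here is a dependency-free sketch of multi-head self-attention over a 3 × 132 input. It follows the standard formulation in which the projected Q, K and V are split evenly across heads and each head scales by the square root of its own width (with a single head this reduces to the sqrt(d_model) of formula (3)). The random weights and helper names are hypothetical stand-ins for trained parameters.

```python
import math
import random


def matmul(a, b):
    """Plain-list matrix product: (n, m) x (m, p) -> (n, p)."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def linear(x, weight, bias):
    """Fully connected layer applied to each row: x W + b."""
    return [[v + b for v, b in zip(row, bias)] for row in matmul(x, weight)]


def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(q, k, v, d):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d)) V."""
    k_t = [list(col) for col in zip(*k)]
    scores = matmul(q, k_t)
    weights = [softmax([s / math.sqrt(d) for s in row]) for row in scores]
    return matmul(weights, v)


def multi_head_self_attention(x, params, num_heads):
    """Self-attention over the rows (tokens) of x with num_heads heads."""
    n_tokens, d_model = len(x), len(x[0])
    assert d_model % num_heads == 0, "d_model must split evenly across heads"
    d_head = d_model // num_heads
    # Q, K and V are all projected from the same x: this is what makes it self-attention.
    q, k, v = (linear(x, *params[name]) for name in ("q", "k", "v"))
    heads = []
    for h in range(num_heads):
        part = slice(h * d_head, (h + 1) * d_head)
        heads.append(attention([r[part] for r in q], [r[part] for r in k],
                               [r[part] for r in v], d_head))
    concat = [sum((head[i] for head in heads), []) for i in range(n_tokens)]
    return linear(concat, *params["o"])  # output projection W^O after concatenation


def random_params(d_model, seed=0):
    """Hypothetical small random weights for the Q, K, V and output layers."""
    rng = random.Random(seed)

    def layer():
        weight = [[rng.gauss(0.0, 0.05) for _ in range(d_model)] for _ in range(d_model)]
        return weight, [0.0] * d_model

    return {name: layer() for name in ("q", "k", "v", "o")}
```

With d_model = 132 and h = 2, each head works on 66 dimensions. The head counts compared later (2, 3, 6, 11 and 12) all divide 132, which is consistent with an implementation that splits the embedding across heads in this way (PyTorch's `nn.MultiheadAttention`, for instance, requires `embed_dim` to be divisible by `num_heads`), although the text does not name the framework used.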
Focal loss function. Class imbalance is a common problem in medical data processing and analysis, and it also exists in our MIMIC-III datasets, as shown in Table 2. Therefore, we apply the focal loss function to deal with it. The focal loss function was proposed for one-stage detectors in image object detection [28]. By reducing the weight of the large number of negative samples in the training data, focal loss makes the model focus on the category with fewer samples during training. Meanwhile, by reducing the weight of samples that are easy to classify, it improves the accuracy on samples that are difficult to classify.
The focal loss function is developed from the cross entropy loss function. For binary classification, the cross entropy is defined as
$$\mathrm{CE}(p, y) = \begin{cases} -\log(p) & \text{if } y = 1 \\ -\log(1 - p) & \text{otherwise} \end{cases}$$
where y ∈ {0, 1} denotes the true label in the dataset (in this research, label 1 is a dead HF patient) and p ∈ [0, 1] is the model's predicted probability for label 1. For simplicity, we define
$$p_t = \begin{cases} p & \text{if } y = 1 \\ 1 - p & \text{otherwise} \end{cases}$$
so that the cross entropy loss can be written as
$$\mathrm{CE}(p, y) = \mathrm{CE}(p_t) = -\log(p_t)$$
Adding a weighting factor and a modulating factor to this loss gives the focal loss:
$$\mathrm{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \log(p_t) \tag{8}$$
In formula (8), α_t and γ are two hyper-parameters. α_t adjusts the relative weight of positive and negative samples, and γ down-weights easy samples so that training focuses on samples that are difficult to separate. In this study, we set α_t = 0.25 and γ = 2.
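A minimal per-sample sketch of formula (8) for the binary case. Following the original focal-loss formulation, we assume α_t = α for positive samples and 1 − α for negatives, with α = 0.25 and γ = 2 as stated above; the clipping constant `eps` and the function names are our own additions.

```python
import math


def binary_focal_loss(p, y, alpha=0.25, gamma=2.0, eps=1e-7):
    """FL(p_t) = -alpha_t * (1 - p_t) ** gamma * log(p_t) for one sample.

    p is the predicted probability of the positive class (dead patient, y = 1).
    alpha_t = alpha for positives and 1 - alpha for negatives (assumed convention).
    """
    p = min(max(p, eps), 1.0 - eps)  # keep log() finite
    p_t = p if y == 1 else 1.0 - p
    alpha_t = alpha if y == 1 else 1.0 - alpha
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


def mean_focal_loss(probs, labels, **kwargs):
    """Average focal loss over a batch."""
    return sum(binary_focal_loss(p, y, **kwargs) for p, y in zip(probs, labels)) / len(labels)
```

With γ = 0 and α = 0.5 the loss is half the ordinary cross entropy; with γ = 2 a confidently correct prediction (p_t = 0.9) is scaled down by a further factor of (1 − 0.9)² = 0.01, which is how easy samples lose influence.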

Training strategy
First, we adopt 5-fold cross-validation to train our model and avoid over-fitting and under-fitting. Each dataset is divided into a training set, a validation set and a test set, where 5% of each dataset is randomly selected as the test set. The model is trained for 120 epochs with the Root Mean Square Propagation (RMSprop) optimizer and an initial learning rate of 0.001.
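One possible reading of this protocol is sketched below: hold out 5% of a dataset at random as the test set, then split the remainder into 5 folds, each serving once as the validation set. The fold layout, the seed and the function names are assumptions, and the text does not say whether the splits were stratified by label.

```python
import random


def split_indices(n_samples, test_fraction=0.05, n_folds=5, seed=0):
    """Hold out a random test set, then cut the remaining indices into n_folds folds."""
    rng = random.Random(seed)
    idx = list(range(n_samples))
    rng.shuffle(idx)
    n_test = round(n_samples * test_fraction)
    test_idx, rest = idx[:n_test], idx[n_test:]
    folds = [rest[i::n_folds] for i in range(n_folds)]
    return test_idx, folds


def fold_pairs(folds):
    """Yield (train, validation) index lists, using each fold once for validation."""
    for i, val in enumerate(folds):
        train = [j for k, fold in enumerate(folds) if k != i for j in fold]
        yield train, val
```

For a dataset as imbalanced as W365D, a stratified variant that keeps the label ratio constant across folds is a common alternative to this plain random split.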
In the model evaluation step, we use seven criteria: accuracy (ACC), positive predictive value (PPV, also called precision), negative predictive value (NPV), Recall, F1 score (F1), the area under the receiver operating characteristic curve (AUC) and the micro-average of AUC. The first five are defined in formulas (10) to (14):
$$\mathrm{ACC} = \frac{TP + TN}{TP + TN + FP + FN}$$
$$\mathrm{PPV} = \frac{TP}{TP + FP}$$
$$\mathrm{NPV} = \frac{TN}{TN + FN}$$
$$\mathrm{Recall} = \frac{TP}{TP + FN}$$
$$F1 = \frac{2 \times \mathrm{PPV} \times \mathrm{Recall}}{\mathrm{PPV} + \mathrm{Recall}}$$
In the above formulas, TP is the number of true positive samples, TN the number of true negative samples, FP the number of false positive samples and FN the number of false negative samples. Because the datasets are imbalanced, model stability is crucial; hence, we also report the micro-average of AUC to reflect stability on the small samples.
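Formulas (10) to (14) and the two AUC variants can be sketched in plain Python as below. The AUC uses its rank (Mann-Whitney) interpretation, and the micro-average pools the two one-hot columns (alive, dead) of the binary output, which is our assumption about how the two-class output is averaged; the function names are hypothetical.

```python
from itertools import product


def confusion_metrics(tp, tn, fp, fn):
    """ACC, PPV, NPV, Recall and F1 from confusion counts (0.0 when undefined)."""
    acc = (tp + tn) / (tp + tn + fp + fn)
    ppv = tp / (tp + fp) if tp + fp else 0.0
    npv = tn / (tn + fn) if tn + fn else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * ppv * recall / (ppv + recall) if ppv + recall else 0.0
    return {"ACC": acc, "PPV": ppv, "NPV": npv, "Recall": recall, "F1": f1}


def auc(labels, scores):
    """ROC AUC as the probability that a positive outscores a negative (ties count half)."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p, n in product(pos, neg))
    return wins / (len(pos) * len(neg))


def micro_auc(labels, probs):
    """Micro-average over the (alive, dead) one-hot columns: pool both columns, then one AUC."""
    flat_labels = [1 - y for y in labels] + list(labels)
    flat_scores = [1.0 - p for p in probs] + list(probs)
    return auc(flat_labels, flat_scores)
```

On labels `[0, 0, 1, 1]` with predicted death probabilities `[0.1, 0.4, 0.35, 0.8]`, the AUC is 0.75 and the pooled micro-average is 0.8125, so the two summaries can differ noticeably even on a tiny sample.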

Mortality prediction for HF patients
In this subsection, we use the proposed model to predict HF patients' mortality. First, we apply the model to the W30D dataset. As shown in Fig 4A, the model gradually stabilizes after 80 epochs. The loss decreases from 0.9477 to 0.0417, the ACC stays at 84.56%, and the F1 score rises from 23.01% to 78.69%. These results show that the model can distinguish dead patients from living patients. In addition, the AUC in Fig 5A is 91.00%, indicating that the model has good stability, and the micro-average AUC is also 91.00%, which indicates that the model can deal with imbalanced data.
Secondly, we measure the model performance on the W180D dataset. The loss settles at 0.0423 after 70 epochs (Fig 4B). Because the imbalance rate of W180D is higher than that of W30D, the ACC drops by about 2.5 percentage points to 82.08%, and the F1 score falls to 56.33%. However, the model remains stable: the AUC of both classes is 82.00% (Fig 5B), and the micro-average AUC reaches 88.00%. This shows that the model still handles the imbalanced dataset well.
Afterwards, as the imbalance rate increases further in W365D, the loss becomes more volatile, as the loss curve in Fig 4C shows: it fluctuates between 0.022 and 0.024. At the same time, the AUC of both classes is 75.00% (Fig 5C), indicating decreased stability. Moreover, there is a large gap between the micro-average AUC (91.00%) and the macro-average AUC (75.00%). This difference indicates that the model has difficulty handling a dataset with an imbalance rate of 7.9981. Although the ACC, at 88.56%, is higher than on the other datasets, the F1 score is only 34.35%. This experiment shows that the predictions are biased toward the majority class.
Finally, we discuss the A365D dataset. The loss stays around 0.050 (Fig 4D), slightly higher than on the other datasets, and the ACC and F1 are 76.07% and 46.34%, respectively. A365D records HF patients who died more than 1 year after admission, yet 66.19% (1110/1677) of these patients survived 2 years or more. Therefore, compared with the other datasets, the patients who died after one year are more similar in their characteristics to those who did not die during the statistical period, which explains the model's lower ACC on A365D. In addition, the AUC of both classes reaches 72%, and Fig 5D reports a micro-average AUC of 93%, indicating the model's stability.

Model performance compared with other methods
To demonstrate the validity of DLS-MSM, we compare it with six models that are representative and widely applied in the medical and biomedical fields: Support Vector Machine (SVM) [29], Multi-layer Perceptron (MLP) [24], Logistic Regression (LR), Random Forest (RF) [30], Light Gradient Boosting Machine (LightGBM, LGB) [31] and K-Nearest Neighbor (KNN). Table 6 shows the comparison. All decimals have been converted into percentages, and the best result for each metric is shown in bold. All results are obtained on the test set.
Because the W30D dataset is large enough and has a small imbalance ratio, all models perform well on it; RF, for example, obtains 86.31% ACC, 1.75% higher than DLS-MSM. The F1 score, the harmonic mean of PPV and Recall, of DLS-MSM improves by 12.21%, 21.12% and 16.565% on W180D, W365D and A365D, respectively, which indicates that our model handles these datasets better. As shown in Table 6, the F1 scores of the SVM and RF models are null on the W365D dataset, which indicates that these models have difficulty handling imbalanced data. In contrast, DLS-MSM can accurately identify dead HF patients among a large number of negative samples. Moreover, the NPV of the four algorithms is basically flat at 99.88%, which indicates that these algorithms cannot solve the imbalance problem on the W365D dataset. In this respect, DLS-MSM, with the highest PPV and F1 score, is much better than the other algorithms.
On the one hand, in terms of AUC, the model remains relatively stable. Because of the higher imbalance ratio in W365D, LGB virtually ignores the positive samples, as reflected in its higher Recall and lower F1. Hence, LGB is hard to adopt for mortality prediction. The same explanation applies to RF, which reaches 88.70% AUC on A365D. On the other hand, the micro-average AUC shows that DLS-MSM is effective in dealing with the imbalance problem, improving it by around 1.5%.

The effect of indicator vector
To verify the validity of the indicator vector, we carry out a comparative experiment in which the dataset without the indicator vector serves as the control group. The input dimension in this experiment therefore changes to 66, and the dimensions of the attention layer and the linear layer change to (2, 96) and (288, 128), respectively. Table 7 displays the comparison, which establishes the validity of the indicator vector: all metrics of the model with the indicator vector are higher than those without it, except on W365D. W365D is the most imbalanced dataset; even there, the F1 without the indicator vector is lower than with it, because the model trained without it largely ignores the positive samples. Adding the indicator helps the model identify more of the true positive samples. Furthermore, the NPV is higher than without the indicator vector, which indicates that the model handles both positive and negative samples more comprehensively.

The effect of using attention
Attention plays an important role in the CNN framework. We therefore compare the CNN framework with and without attention, as shown in Table 8. PPV and Recall affect each other: Recall is higher for DLS-MSM with the attention mechanism than without it, whereas PPV is higher for the model without attention. This result relates to the recognition rate, since a higher Recall means the model identifies more of the true positive samples. The NPV shows that DLS-MSM with attention can accurately identify negative samples. In addition, the model with attention is more stable, as shown by the AUC and the micro-average AUC. Overall, the proposed DLS-MSM system with attention is better than without attention, except for Recall on the W30D dataset.

Effect of attention in different locations
The attention mechanism plays different roles at different locations in the system structure, so we conduct two experiments to study where it is most effective. Following the principle of the attention mechanism, we choose position one, between the convolution layer and the different kernels layer, and position two, between the different kernels layer and the stitching layer, as shown in Fig 2. The sizes of the attention mechanism at these positions are (2, 90) and (2, 44), respectively.
As shown in Table 9, the Recall on the W180D and A365D datasets at position 1 is 5.64% and 26.87% higher, respectively, than at position 2. However, the PPV is not higher. Hence

Effect of different number of heads
The multi-head self-attention mechanism can use different numbers of heads, which may influence system performance. We therefore analyze 2, 3, 6, 11 and 12 heads; the results are shown in Fig 6. Because more heads increase the model complexity and the training time, we also measure the average training time for each number of heads, shown in Fig 7 in seconds. The upper-left panel of Fig 6 shows the results on W30D, where Recall fluctuates widely. With 12 heads, the Recall is 85.29%, but F1 does not increase significantly relative to the other settings. A high Recall with a low F1 indicates that more true negative samples are predicted as positive. Six heads behave similarly, while 3 and 11 heads perform poorly, with Recall of only 72.48% and 69.06%, respectively. Hence, on W30D two heads give the better performance and also need less training time. The other datasets in Fig 6 show even clearer differences in model performance. Overall, two attention heads give the best results with the least time consumption in our study.

Important feature ranking based Deep SHAP
Neural networks are often regarded as black boxes, and interpreting such complex models has become an active research topic. We adopt Deep SHAP [32] to interpret our deep learning model and rank the features. Deep SHAP is an interpretation method for deep learning models that combines DeepLIFT and SHAP values. DeepLIFT interprets a deep learning model by computing the contribution of each input feature through backpropagation [33]. SHAP (SHapley Additive exPlanations) assigns each feature a unique importance score. Scott M. Lundberg and Su-In Lee   Table 10.
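Deep SHAP approximates Shapley values for deep networks; the exact quantity it targets can be computed by brute force for a tiny model. The sketch below is that brute-force Shapley computation, not Deep SHAP itself, and it "removes" a feature by setting it to a baseline value, which is one common convention and an assumption here. For the real network, the `shap` package's `DeepExplainer` is the usual entry point.

```python
from itertools import combinations
from math import factorial


def shapley_values(f, x, baseline):
    """Exact Shapley values of model f at input x.

    Features outside a coalition take their baseline value. The cost grows
    exponentially with the number of features, so this only suits toy examples.
    """
    n = len(x)

    def value(coalition):
        return f([x[i] if i in coalition else baseline[i] for i in range(n)])

    phi = []
    for i in range(n):
        others = [j for j in range(n) if j != i]
        total = 0.0
        for size in range(n):
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = set(subset)
                total += weight * (value(s | {i}) - value(s))  # marginal contribution of i
        phi.append(total)
    return phi
```

For a linear model f(v) = 2v₀ + 3v₁ − v₂ at x = (1, 2, 3) with a zero baseline, each score equals the weight times the feature's deviation from the baseline, (2, 6, −3), and the scores sum to f(x) − f(baseline); ranking features by such scores is what produces an ordering like the one in Table 10.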
Heart failure with severe hypoxia and ischemia can be accompanied by severe arrhythmia, especially ventricular fibrillation, whose clinical symptoms include loss of consciousness, convulsions, respiratory arrest and even death. Respiratory failure is therefore an important feature. The RAAS system is activated when heart failure occurs, and ARB drugs can inhibit the RAAS system and myocardial remodeling and delay the progression of heart failure. Calcium channel blockers (CCB) can reduce the calcium concentration in myocardial cells, improving active myocardial diastolic function and lowering blood pressure.
These clinical findings show that the features extracted by our system, listed in Table 10, are scientifically sound and effective. Therefore, our system can support doctors in prognostic treatment.

Conclusion
In our study, a CNN deep learning model based on multi-head self-attention is applied in a mortality prediction system for prognostic HF patients. The system distinguishes four categories of death: death within 30 days, within 180 days, within 365 days and after 365 days. First, we proposed an indicator vector that marks whether each value is true or filled. Then, multi-head self-attention is introduced into the CNN deep learning model. Lastly, the focal loss function is applied to overcome the imbalance. The experimental results show that the idea is feasible and that the whole system is effective in predicting mortality. To explain the system, we also use the Deep SHAP method to produce a reasonable ranking of the essential features.
Supporting information
S1 Data. SQL code for extracting the data from MIMIC-III. (7Z)